#############################################################
# Author: Mike Burnham, mlb6496@psu.edu
# Python: 3.11.5
# OS: Windows 10
#
# Notes: This script creates Figures 2 and 4. To create
# Figure 3, see sect_5_analysis.R.
##############################################################

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

############
## Figure 2
############

# Figure 2: MCC (with confidence intervals) per model, split by whether Trump
# is mentioned, read from ./mcc_matrix.csv and ./mcc_matrix_notrump.csv.

# import data
mcc = pd.read_csv('./mcc_matrix.csv')
mcc_notrump = pd.read_csv('./mcc_matrix_notrump.csv')

# keep only the rows for the approaches shown in the plot
plot_approaches = ['Logit Bias', 'NLI', 'Supervised']
mcc = mcc[mcc['approach'].isin(plot_approaches)]
mcc_notrump = mcc_notrump[mcc_notrump['approach'].isin(plot_approaches)]

# Combine the two dataframes for easy plotting. The outer index level records
# which file each row came from and is used as the hue variable below.
condition_order = ['Trump Mentioned', 'Trump Not Mentioned']
df_combined = pd.concat([mcc, mcc_notrump], keys=condition_order)

# subset to desired models
df_combined = df_combined[df_combined['model'].isin(['GPT-4', 'Mistral', 'DeBERTa', 'PoliBERTweet', 'RoBERTa'])]

# rename models for plotting (assign the result instead of mutating the
# filtered frame in place)
df_combined = df_combined.replace({'GPT-4': 'GPT-4\n(In-Context)',
                                   'Mistral': 'Mistral 7B\n(In-Context)',
                                   'DeBERTa': 'DeBERTa\n(NLI)',
                                   'PoliBERTweet': 'PoliBERTweet\n(Supervised)',
                                   'RoBERTa': 'RoBERTa\n(Supervised)'
                                   })

condition = df_combined.index.get_level_values(0)
# Models in order of first appearance (seaborn's default category order); made
# explicit so the error-bar positions below can be computed from it.
model_order = list(pd.unique(df_combined['model']))

# Create a strip plot of the MCC point estimates
fig, ax = plt.subplots(figsize=(12, 8))

sns.stripplot(x='mcc', y='model', data=df_combined, hue=condition,
              order=model_order, hue_order=condition_order,
              palette=['black', 'white'], edgecolor='black', linewidth=.75,
              size=8, dodge=True, jitter=False, ax=ax)

# Place each row's error bar on that row's own point. Positions are computed
# from the data instead of being read back from ax.collections, whose order
# follows seaborn's drawing loop rather than the row order of df_combined.
# The offsets mirror seaborn's default dodge layout: a 0.8-wide band around
# each category position split evenly across the hue levels (re-check if the
# seaborn version changes).
n_levels = len(condition_order)
dodge_step = 0.8 / n_levels
hue_offset = {level: (i - (n_levels - 1) / 2) * dodge_step
              for i, level in enumerate(condition_order)}
y_coords = [model_order.index(model) + hue_offset[level]
            for model, level in zip(df_combined['model'], condition)]

mcc_vals = df_combined['mcc'].to_numpy()
lower_err = mcc_vals - df_combined['lower'].to_numpy()
upper_err = df_combined['upper'].to_numpy() - mcc_vals

# add error bars: whiskers only (fmt='none'), drawn beneath the markers so the
# white points remain visible
ax.errorbar(x=mcc_vals, y=y_coords, xerr=[lower_err, upper_err],
            fmt='none', color='black', capsize=5, capthick=1, zorder=0.5)

# add line for human coder performance
human_coder_mcc = 0.71
ax.axvline(x=human_coder_mcc, color='black', linestyle='--', label='Human Coders')

ax.set_xlim(left=0)
ax.set_xlabel('MCC', fontsize=16)
ax.set_ylabel(None)
ax.tick_params(axis='both', which='major', labelsize=14)

# Adjust the legend and layout
plt.legend(title='', fontsize=14)
plt.tight_layout()
# export
plt.savefig('./Figure_2.png', dpi=300)

############
## Figure 4
############
# Figure 4: violin plots of 'ideology' for the rows where each panel's
# indicator column equals 1, from ./threatmin1.csv and ./threatmin2.csv.
zs1 = pd.read_csv('./threatmin1.csv')
zs2 = pd.read_csv('./threatmin2.csv')

gray = sns.color_palette("Greys")[3]

# (source frame, indicator column, x-axis label) for each panel, left to right.
panels = [
    (zs1, 'non_comp', 'Supervised'),
    (zs1, 'threatmin', 'NLI Hypothesis Set 1'),
    (zs2, 'threatmin', 'NLI Hypothesis Set 2'),
]

# Panels 2 and 3 hide their y tick labels, so the y axis is shared to keep
# all three violins on the same ideology scale as the labelled first panel.
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12, 6))

for i, (panel_ax, (frame, flag_col, panel_label)) in enumerate(zip(axes, panels)):
    sns.violinplot(ax=panel_ax, y=frame.loc[frame[flag_col] == 1, 'ideology'], color=gray)
    panel_ax.set_xlabel(panel_label, fontsize=18)
    if i == 0:
        panel_ax.set_ylabel('Ideology', fontsize=18)
    else:
        panel_ax.set_ylabel(None)  # Remove y-axis label
        panel_ax.tick_params(labelleft=False, left=False)

fig.tight_layout()
fig.subplots_adjust(wspace=0)
fig.savefig('./Figure_4.png', dpi=200)